import pickle
import os.path as osp

import torch as th
import numpy as np
import clip
from tqdm import tqdm

from diffgro.environments import make_env
from diffgro.environments.variant import *
from diffgro.utils import *


def add_noise(action: np.ndarray, std: float = 0.001):
    """Return a copy of *action* perturbed by zero-mean Gaussian noise.

    Args:
        action: Action array; it is not modified.
        std: Standard deviation of the sampled noise.

    Returns:
        A new array ``action + noise`` in which the last entry carries no
        noise (the original comment labels that entry "grab" -- presumably
        the gripper command; confirm against the environment).
    """
    perturbation = np.random.normal(loc=0.0, scale=std, size=action.shape)
    # Keep the last entry exactly as the expert produced it.
    perturbation[-1] = 0.0
    return action + perturbation


def compensate_environment_dynamics(action: np.ndarray, env):
    """Scale *action* by the dynamics factors exposed on *env*.

    The first three entries are multiplied by ``env.arm_speed``; afterwards
    entry 0 is multiplied by ``env.wind_xspeed`` and entry 1 by
    ``env.wind_yspeed``.

    Args:
        action: Action array. It is modified IN PLACE.
        env: Object exposing ``arm_speed``, ``wind_xspeed`` and
            ``wind_yspeed`` attributes.

    Returns:
        The same ``action`` object that was passed in.
    """
    # Arm speed applies to every one of the first three components.
    action[:3] = action[:3] * env.arm_speed
    # Wind factors apply per axis: index 0 <- x wind, index 1 <- y wind.
    wind_factors = (env.wind_xspeed, env.wind_yspeed)
    for axis, factor in enumerate(wind_factors):
        action[axis] = action[axis] * factor
    return action


def check_action(skill: str, action: np.ndarray):
    """Report an action whose first three entries leave the range [-1, 1].

    Only indices 0, 1 and 2 are inspected; values exactly equal to -1.0 or
    1.0 are accepted. Nothing is returned and *action* is not modified; a
    message is emitted through ``print_r`` when a violation is found.

    Args:
        skill: Skill name included in the message.
        action: Action to inspect.
    """
    exceeds_bounds = any(
        action[axis] > 1.0 or action[axis] < -1.0 for axis in range(3)
    )
    if exceeds_bounds:
        print_r(f"{skill} should be clipped {action}")


def get_skill_embed(lm, skill: str):
    """Encode a skill (or task) name into a CLIP text embedding.

    Args:
        lm: Model exposing ``encode_text``. When ``None``, a "ViT-B/16" model
            is loaded with ``clip.load`` for this call only.
        skill: Name to embed; underscores are replaced by spaces before
            tokenization.

    Returns:
        NumPy array holding the squeezed output of ``lm.encode_text``.
    """
    if lm is None:
        # NOTE: this loads the model on every call; pass a preloaded model
        # when many embeddings are needed.
        lm, _ = clip.load("ViT-B/16")

    # Tokens must live on the same device as the model. Ask the model itself
    # instead of guessing from CUDA availability, so a model kept on the CPU
    # (or on a non-default GPU) still works when CUDA is present.
    parameters = getattr(lm, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        # Fallback for objects without parameters: previous heuristic.
        device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # skill embedding through clip
    token = clip.tokenize(skill.replace("_", " ")).to(device)
    with th.no_grad():
        embed = lm.encode_text(token).squeeze()
    return embed.cpu().numpy()


def run_environment_loop(
    env,
    lm,
    early_stop: bool = True,
    video: bool = False,
    warmup: bool = True,
    verbose: bool = False,
):
    """Roll out the environment's expert for one episode and record it.

    Args:
        env: Environment exposing ``domain_name``, ``env_name``, ``reset``,
            ``get_exp``, ``step`` and ``render``, plus the dynamics
            attributes read by ``compensate_environment_dynamics``.
        lm: Language model forwarded to ``get_skill_embed``.
        early_stop: Stop as soon as ``info["success"]`` is truthy.
        video: Render and annotate a frame after every step.
        warmup: Forwarded to ``env.reset``.
        verbose: Run ``check_action`` on the raw expert action.

    Returns:
        ``(trajectory, frames)``. ``trajectory`` maps each recorded key to a
        NumPy array; ``"observations"`` holds one more entry than
        ``"actions"`` because the final observation is appended. ``frames``
        is the list of annotated frames (empty when ``video`` is false).
    """
    domain = env.domain_name
    task = env.env_name
    task_embed = get_skill_embed(lm, task)  # task embedding

    trajectory = {
        "task": task_embed,
        "observations": [],
        "actions": [],
        "skill_langs": [],  # language
        "skill_embds": [],  # embeddings
        "rewards": [],
        "terminals": [],
        "infos": [],
    }
    # The same few skill names recur across steps; encode each one once
    # instead of running the text encoder on every environment step.
    embed_cache = {}

    frames = []
    obs, done = env.reset(warmup=warmup), False
    exp = env.get_exp()

    while not done:
        action, skill = exp.get_action(obs, return_skill=True)
        if verbose:
            check_action(skill, action)
        if domain == 'metaworld_complex':
            # In this domain the recorded skill label comes from the task
            # list, indexed by the environment's success counter.
            skill = env.full_task_list[env.success_count]

        # compensate for environment dynamics
        action = compensate_environment_dynamics(action, env)
        # add action noise
        action = add_noise(action)
        # action clipping
        action = np.clip(action, -1, 1)

        next_obs, reward, done, info = env.step((action, skill))

        if skill not in embed_cache:
            embed_cache[skill] = get_skill_embed(lm, skill)  # skill embedding

        trajectory["observations"].append(obs)
        trajectory["actions"].append(action)
        trajectory["skill_langs"].append(skill)
        trajectory["skill_embds"].append(embed_cache[skill])
        trajectory["rewards"].append(reward)
        trajectory["terminals"].append(done or info["success"])
        trajectory["infos"].append(info)

        obs = next_obs

        if video:
            frame = env.render()
            frame = write_annotation(frame, skill)
            frames.append(frame)

        if early_stop and info["success"]:
            break

    trajectory["observations"].append(obs)  # append last obs
    trajectory = {key: np.array(value) for key, value in trajectory.items()}

    return trajectory, frames


def _build_variant_spaces(domain_name: str):
    """Return the variant spaces that collection cycles through for a domain.

    Raises:
        ValueError: If no variant space is defined for *domain_name*.
    """
    if domain_name == "metaworld":
        # Active configuration: every factor is varied jointly in one space.
        #
        # Alternative configuration (not active): three spaces, each varying
        # a single factor while the other two stay at 1.0, without
        # "goal_resistance":
        #   arm_speed   in [0.6, 0.8, 1.0, 1.2, 1.4]
        #   wind_xspeed in [0.8, 0.9, 1.0, 1.1, 1.2]
        #   wind_yspeed in [0.8, 0.9, 1.0, 1.1, 1.2]
        return [
            VariantSpace({
                "arm_speed": Categorical(a=[0.6, 0.8, 1.0, 1.2, 1.4]),
                "wind_xspeed": Categorical(a=[0.8, 0.9, 1.0, 1.1, 1.2]),
                "wind_yspeed": Categorical(a=[0.8, 0.9, 1.0, 1.1, 1.2]),
                "goal_resistance": Categorical(a=[0, 1, 2]),
            }),
        ]

    if domain_name == "metaworld_complex":
        # Excluded objects do not need pull/push skills, therefore goal_resistance is not used
        goal_resistance = VariantSpace({
            "button": Categorical(a=[0, 1, 2]),
            "drawer": Categorical(a=[0, 1, 2]),
        })

        # Active configuration: "T3" (only arm_speed varies).
        #
        # Alternative configurations (not active):
        #   T2: three spaces, each varying a single factor while the other
        #       two stay at 1.0, all with the goal_resistance above
        #       (arm_speed in [0.6, 0.8, 1.0, 1.2, 1.4];
        #        wind_xspeed / wind_yspeed in [0.8, 0.9, 1.0, 1.1, 1.2]).
        #   T1: one space varying all three factors jointly over the same
        #       value lists, with the goal_resistance above.
        #   Fixed: one space with arm_speed, wind_xspeed and wind_yspeed all
        #       at 1.0 and no goal_resistance.
        return [  # T3
            VariantSpace({
                "arm_speed": Categorical(a=[0.8, 0.9, 1.0, 1.1, 1.2]),
                "wind_xspeed": Categorical(a=[1.0]),
                "wind_yspeed": Categorical(a=[1.0]),
                "goal_resistance": goal_resistance,
            }),
        ]

    raise ValueError(f"No variant space is defined for domain '{domain_name}'")


def collect_dataset(args):
    """Collect successful expert trajectories and save them to disk.

    Args:
        args: Namespace providing ``env_name`` (formatted
            ``"<domain>.<env>"``), ``seed``, ``n_episodes``, ``video``,
            ``verbose`` and ``save_path``.

    Raises:
        ValueError: If the environment name contains "variant" but no
            variant space is defined for its domain.

    NOTE: the loop only ends once ``n_episodes`` trajectories have been
    accepted; there is no cap on the number of rejected attempts.
    """
    print_r(f"<< Collecting Dataset for {args.env_name}... >>")

    domain_name, env_name = args.env_name.split(".")
    env = make_env(domain_name, env_name, seed=args.seed)
    print_b(f"Obs Space: {env.observation_space.shape}, Act Space: {env.action_space.shape}")

    # Set variant configuration, will be randomly sampled on reset
    use_variants = "variant" in env_name
    variant_space = _build_variant_spaces(domain_name) if use_variants else []

    # Pass trajectories with specific skills, to regulate skill sequence
    excluded_skill = {
        "button-press-variant-v2": "move_right",
        "drawer-close-variant-v2": "move_left",
    }.get(env_name)

    lm, _ = clip.load("ViT-B/16")

    logger = Logger("Expert Evaluation")
    trajectory_list, frame_list = [], []

    # The context manager closes the progress bar even if a rollout raises.
    with tqdm(total=args.n_episodes) as pbar:
        while len(trajectory_list) < args.n_episodes:
            # update variants (cycle through the configured spaces)
            if use_variants:
                env.update_variant_space(
                    variant_space[len(trajectory_list) % len(variant_space)]
                )

            trajectory, frames = run_environment_loop(
                env,
                lm,
                early_stop=True,
                video=args.video,
                verbose=args.verbose,
            )
            success = trajectory["infos"][-1]["success"]

            # Pass failed or abnormal trajectories
            failed = (domain_name == "metaworld" and not success) or (
                domain_name == "metaworld_complex" and success != 1.0
            )
            if failed:
                print(f"continue: {success}")
                print(f"variant: {env.variant}")
                continue

            if excluded_skill is not None and excluded_skill in trajectory["skill_langs"]:
                print("continue??")
                continue

            trajectory_list.append(trajectory)
            frame_list.append(frames)

            logger.log("Total Success Rate", success, percent=True)
            logger.log("Total Length", len(trajectory["actions"]))
            pbar.update(1)

        logger.print()

    save_path = osp.join(args.save_path, domain_name, env_name)
    save_dataset(trajectory_list, frame_list, save_path)


def save_dataset(trajectory_list, frame_list, save_folder):
    """Write every trajectory (and its video, if any) under *save_folder*.

    Trajectory ``i`` is pickled to ``<save_folder>/trajectory/episode_<i>.pkl``.
    When the matching entry of *frame_list* is non-empty, it is passed to
    ``save_video`` with the path ``<save_folder>/video/episode_<i>.mp4``.

    Args:
        trajectory_list: Trajectories to pickle, one file per entry.
        frame_list: Per-episode frame lists, paired with *trajectory_list*
            by position.
        save_folder: Root output directory.
    """
    print_r(f"<< Saving dataset in {save_folder}... >>")

    trajectory_folder = osp.join(save_folder, "trajectory")
    video_folder = osp.join(save_folder, "video")
    make_dir(trajectory_folder)

    for index, (episode_data, episode_frames) in enumerate(
        zip(trajectory_list, frame_list)
    ):
        pickle_path = osp.join(trajectory_folder, f"episode_{index}.pkl")
        with open(pickle_path, "wb") as handle:
            pickle.dump(episode_data, handle)

        # Episodes recorded without video have nothing more to write.
        if not episode_frames:
            continue

        make_dir(video_folder)
        save_video(osp.join(video_folder, f"episode_{index}.mp4"), episode_frames)
